#include <math.h>
#include <stdlib.h>
#include <assert.h>
#include <mpi.h>
#include <map>
#include "esmd_types.h"
#include "cell.h"
#include "bonded.h"
#include "memory_cpp.hpp"
extern "C"{
#include "hilbert.h"
}

#include "comm.h"
/*
 * Set up the cell grid for the local box and bin the given atoms into it.
 *
 * Geometry written to *grid:
 *   - nlocal = floor(llen / ((rcut + skin) / nn)) cells per dimension inside lbox,
 *   - len / rlen = resulting cell edge length and its reciprocal,
 *   - skin = (len * nn - rcut) / 2,
 *   - nall = nlocal + 2 * nn, with cell indices spanning [-nn, nlocal + nn),
 *     i.e. nn extra cell layers on every side of the local cells.
 *
 * Atom handling:
 *   - every atom is first wrapped periodically into gbox;
 *   - an atom that is then outside lbox is dropped when overflow == OVF_SKIP,
 *     otherwise it is shifted ONCE by the local box length per dimension;
 *   - the atom is stored relative to its cell's basis, tagged with its index i
 *     in the input arrays, and rmass is set to 1 / mass.
 *
 * grid->natoms receives the MPI_COMM_WORLD sum of the atoms kept on each rank
 * (reduced as MPI_LONG; assumes grid->natoms is a long -- TODO confirm).
 *
 * graph, imidx, excls, scals and chain2 are accepted but not used here.
 */
void build_cells(cellgrid_t *grid, int nn, real rcut, real skin, box<real> * gbox, box<real> * lbox,
                 int natoms, double (*x)[3], real *q, int *t, real *mass,
                 bond_graph_t *graph,
                 impr_index_t *imidx,
                 long (*excls)[MAX_EXCLS_ATOM], long (*scals)[MAX_SCALS_ATOM], long (*chain2)[MAX_CHAIN2_ATOM][2], int overflow) {
  grid->nn = nn;
  grid->rcut = rcut;

  /*edge lengths of the global and local boxes*/
  vec<real> glen = gbox->hi - gbox->lo;
  vec<real> llen = lbox->hi - lbox->lo;
  /*ncells needed: nn cells must cover at least rcut + skin*/
  real lc = (rcut + skin) / nn;
  vec<int> nlocal;
  nlocal = (llen / lc).floor();

  grid->lbox = *lbox;
  grid->gbox = *gbox;
  /*cell length (stretched so that nlocal cells tile lbox exactly)*/
  grid->len = llen / nlocal;
  grid->rlen = nlocal / llen;

  /*safety distance to move*/
  grid->skin = (grid->len * nn - rcut) * 0.5;
  /*save nlocal, nall*/
  grid->nlocal = nlocal;

  grid->nall = grid->nlocal + 2 * nn;
  grid->dim.lo = -nn;
  grid->dim.hi = grid->nlocal + nn;

  /*allocate for cell*/
  int ncell = grid->nall.vol();
  grid->cells = esmd::allocate<celldata_t>(ncell, "cell/data");

  /*calculate cell basis (lower corner of each cell) and empty every cell*/
  FOREACH_CELL(grid, i, j, k, cell) {
    cell->basis.x = i * grid->len.x + lbox->lo.x;
    cell->basis.y = j * grid->len.y + lbox->lo.y;
    cell->basis.z = k * grid->len.z + lbox->lo.z;
    cell->natom = 0;
  }

  /*number of atoms this rank keeps; decremented for every skipped atom*/
  long tot_atoms = natoms;
  for (long i = 0; i < natoms; i++) {
    vec<real> xi(x[i][0], x[i][1], x[i][2]);
    /*periodic wrap into the global box*/
    if (!gbox->contains(xi)) {
      while (xi.x >= gbox->hi.x) xi.x -= glen.x;
      while (xi.y >= gbox->hi.y) xi.y -= glen.y;
      while (xi.z >= gbox->hi.z) xi.z -= glen.z;
      while (xi.x < gbox->lo.x) xi.x += glen.x;
      while (xi.y < gbox->lo.y) xi.y += glen.y;
      while (xi.z < gbox->lo.z) xi.z += glen.z;
    }
    if (!lbox->contains(xi)) {
      if (overflow == OVF_SKIP) {
        /*atom is not ours: leave it out*/
        tot_atoms--;
        continue;
      } else {
        /*shift by one local box length per dimension (single step, not a loop)*/
        if (xi.x >= lbox->hi.x) xi.x -= llen.x;
        if (xi.y >= lbox->hi.y) xi.y -= llen.y;
        if (xi.z >= lbox->hi.z) xi.z -= llen.z;
        if (xi.x < lbox->lo.x) xi.x += llen.x;
        if (xi.y < lbox->lo.y) xi.y += llen.y;
        if (xi.z < lbox->lo.z) xi.z += llen.z;
      }
    }
    /*calculate which cell*/
    int cellx = (int)floor((xi.x - lbox->lo.x) * grid->rlen.x);
    int celly = (int)floor((xi.y - lbox->lo.y) * grid->rlen.y);
    int cellz = (int)floor((xi.z - lbox->lo.z) * grid->rlen.z);
    celldata_t *cell = get_cell_xyz(grid, cellx, celly, cellz);

    assert(cell->natom < CELL_CAP);
    int n = cell->natom;
    /*copy atom data to cell; position is stored relative to the cell basis*/
    cell->x[n].x = xi.x - cell->basis.x;
    cell->x[n].y = xi.y - cell->basis.y;
    cell->x[n].z = xi.z - cell->basis.z;

    cell->q[n] = q[i];
    cell->t[n] = t[i];
    cell->tag[n] = i;

    cell->mass[n] = mass[i];
    cell->rmass[n] = 1. / mass[i];

    cell->natom++;
  }
  MPI_Allreduce(&tot_atoms, &grid->natoms, 1, MPI_LONG, MPI_SUM, MPI_COMM_WORLD);
}

/* Linear search: true when `tag` equals one of the first `len` entries of `list`. */
bool tag_in_list(long tag, long *list, int len) {
  int pos = 0;
  while (pos < len) {
    if (list[pos] == tag) {
      return true;
    }
    ++pos;
  }
  /* also reached when len <= 0 */
  return false;
}
/*
 * Small membership pre-filter for tags.
 * Keeps four 64-bit bitmaps; bitmap w is indexed by the 6-bit window of the
 * tag that starts at bit 4 * w (offsets 0, 4, 8 and 12).
 * Bits are only ever set, so query() == false means the tag was never
 * inserted, while query() == true may be a false positive.
 */
struct bloom_filter {
  long bitmaps[4] = {0, 0, 0, 0};
  /* Set the bit selected by each of the four windows of `tag`. */
  __always_inline void insert(long tag) {
    for (int w = 0; w < 4; w++) {
      bitmaps[w] |= 1L << ((tag >> (4 * w)) & 63L);
    }
  }
  /* True only when all four bits selected by `tag` are set. */
  __always_inline bool query(long tag) const {
    bool hit = true;
    for (int w = 0; w < 4; w++) {
      const long probe = 1L << ((tag >> (4 * w)) & 63L);
      /* non-short-circuit AND, all four bitmaps are always inspected */
      hit = hit & ((bitmaps[w] & probe) != 0);
    }
    return hit;
  }
};
/*
 * True when `tag` occurs in the half-open range [start, end).
 * The filter is consulted first: if it reports the tag as absent the range is
 * not scanned at all, so `filt` must have seen every tag stored in the range.
 */
__always_inline bool is_duplicate(const bloom_filter &filt, long *start, long *end, long tag){
  if (filt.query(tag)) {
    /* possible hit: confirm with an exact scan */
    long *cursor = start;
    while (cursor < end) {
      if (*cursor == tag) {
        return true;
      }
      cursor++;
    }
  }
  return false;
}
/*
 * Initialise the geometry of *grid for the local box and allocate its cells.
 *
 *   - nlocal = floor(llen / ((rcut + skin) / nn)) cells per dimension inside lbox,
 *   - len / rlen = resulting cell edge length and its reciprocal,
 *   - skin = (len * nn - rcut) / 2,
 *   - nall = nlocal + 2 * nn, with cell indices spanning [-nn, nlocal + nn),
 *     i.e. nn extra cell layers on every side of the local cells.
 *
 * Every cell gets its basis (lower corner) and natom = 0; no atoms are placed.
 * The passed `skin` only influences the cell count; grid->skin is recomputed
 * from the final cell length.
 */
void part_cellgrid(cellgrid_t *grid, int nn, real rcut, real skin, const box<real> &gbox, const box<real> &lbox) {
  grid->nn = nn;
  grid->rcut = rcut;

  /*edge lengths of the local box*/
  vec<real> llen = lbox.hi - lbox.lo;
  /*ncells needed: nn cells must cover at least rcut + skin*/
  real lc = (rcut + skin) / nn;
  vec<int> nlocal;
  nlocal = (llen / lc).floor();

  grid->lbox = lbox;
  grid->gbox = gbox;
  /*cell length (stretched so that nlocal cells tile lbox exactly)*/
  grid->len = llen / nlocal;
  grid->rlen = nlocal / llen;

  /*safety distance to move*/
  grid->skin = (grid->len * nn - rcut) * 0.5;
  /*save nlocal, nall*/
  grid->nlocal = nlocal;

  grid->nall = grid->nlocal + 2 * nn;
  grid->dim.lo = -nn;
  grid->dim.hi = grid->nlocal + nn;

  /*allocate for cell*/
  int ncell = grid->nall.vol();
  grid->cells = esmd::allocate<celldata_t>(ncell, "cell/data");

  /*calculate cell basis and empty every cell*/
  FOREACH_CELL(grid, i, j, k, cell) {
    cell->basis.x = i * grid->len.x + lbox.lo.x;
    cell->basis.y = j * grid->len.y + lbox.lo.y;
    cell->basis.z = k * grid->len.z + lbox.lo.z;
    cell->natom = 0;
  }
}
/*
 * Build the cell grid with part_cellgrid() and bin the given atoms into it.
 *
 * Atom handling:
 *   - every atom is first wrapped periodically into gbox;
 *   - an atom that is then outside lbox is dropped when overflow == OVF_SKIP,
 *     otherwise it is shifted ONCE by the local box length per dimension;
 *   - the atom is stored relative to its cell's basis, tagged with its index i
 *     in the input arrays, and rmass is set to 1 / mass.
 *
 * grid->natoms receives the MPI_COMM_WORLD sum of the atoms kept on each rank
 * (reduced as MPI_LONG; assumes grid->natoms is a long -- TODO confirm).
 *
 * graph and imidx are accepted but not read, so they are never dereferenced
 * here.
 */
void build_cells_bondonly(cellgrid_t *grid, int nn, real rcut, real skin, box<real> * gbox, box<real> * lbox,
                 int natoms, double (*x)[3], real *q, int *t, real *mass,
                 bond_graph_t *graph,
                 impr_index_t *imidx,
                 int overflow) {
  part_cellgrid(grid, nn, rcut, skin, *gbox, *lbox);
  /*edge lengths of the global and local boxes*/
  vec<real> glen = grid->gbox.hi - grid->gbox.lo;
  vec<real> llen = grid->lbox.hi - grid->lbox.lo;
  /*number of atoms this rank keeps; decremented for every skipped atom*/
  long tot_atoms = natoms;
  for (long i = 0; i < natoms; i++) {
    vec<real> xi(x[i][0], x[i][1], x[i][2]);
    /*periodic wrap into the global box*/
    if (!gbox->contains(xi)) {
      while (xi.x >= gbox->hi.x) xi.x -= glen.x;
      while (xi.y >= gbox->hi.y) xi.y -= glen.y;
      while (xi.z >= gbox->hi.z) xi.z -= glen.z;
      while (xi.x < gbox->lo.x) xi.x += glen.x;
      while (xi.y < gbox->lo.y) xi.y += glen.y;
      while (xi.z < gbox->lo.z) xi.z += glen.z;
    }
    if (!lbox->contains(xi)) {
      if (overflow == OVF_SKIP) {
        /*atom is not ours: leave it out*/
        tot_atoms--;
        continue;
      } else {
        /*shift by one local box length per dimension (single step, not a loop)*/
        if (xi.x >= lbox->hi.x) xi.x -= llen.x;
        if (xi.y >= lbox->hi.y) xi.y -= llen.y;
        if (xi.z >= lbox->hi.z) xi.z -= llen.z;
        if (xi.x < lbox->lo.x) xi.x += llen.x;
        if (xi.y < lbox->lo.y) xi.y += llen.y;
        if (xi.z < lbox->lo.z) xi.z += llen.z;
      }
    }
    /*calculate which cell*/
    int cellx = (int)floor((xi.x - lbox->lo.x) * grid->rlen.x);
    int celly = (int)floor((xi.y - lbox->lo.y) * grid->rlen.y);
    int cellz = (int)floor((xi.z - lbox->lo.z) * grid->rlen.z);
    celldata_t *cell = get_cell_xyz(grid, cellx, celly, cellz);

    assert(cell->natom < CELL_CAP);
    int n = cell->natom;
    /*copy atom data to cell; position is stored relative to the cell basis*/
    cell->x[n].x = xi.x - cell->basis.x;
    cell->x[n].y = xi.y - cell->basis.y;
    cell->x[n].z = xi.z - cell->basis.z;

    cell->q[n] = q[i];
    cell->t[n] = t[i];
    cell->tag[n] = i;

    cell->mass[n] = mass[i];
    cell->rmass[n] = 1. / mass[i];

    cell->natom++;
  }
  MPI_Allreduce(&tot_atoms, &grid->natoms, 1, MPI_LONG, MPI_SUM, MPI_COMM_WORLD);
}


/*
 * Copy one atom record from slot src_id of src_cell to slot dst_id of dst_cell.
 *
 * Copied fields: x, v, shake_tmp, q, tag, t, mass.
 *
 * The position is copied VERBATIM: it is not re-expressed relative to
 * dst_cell->basis. A caller that moves an atom between cells with different
 * bases has to overwrite dst_cell->x[dst_id] with the rebased coordinate
 * itself. rmass is not copied either; callers recompute it from mass.
 */
INLINE void atomcpy(celldata_t *dst_cell, int dst_id, celldata_t *src_cell, int src_id) {
  veccpy(dst_cell->x[dst_id], src_cell->x[src_id]);
  veccpy(dst_cell->v[dst_id], src_cell->v[src_id]);
  veccpy(dst_cell->shake_tmp[dst_id], src_cell->shake_tmp[src_id]);
  dst_cell->q[dst_id] = src_cell->q[src_id];
  dst_cell->tag[dst_id] = src_cell->tag[src_id];
  dst_cell->t[dst_id] = src_cell->t[src_id];
  dst_cell->mass[dst_id] = src_cell->mass[src_id];
}
/*
 * Report whether any atom of a local cell has moved more than grid->skin
 * outside its cell in some dimension.
 * Returns 1 if at least one such atom exists, 0 otherwise. All local cells are
 * always scanned; there is no early exit.
 */
int cell_check(cellgrid_t *grid) {
  int escaped = 0;
  FOREACH_LOCAL_CELL(grid, cx, cy, cz, cur) {
    int idx = 0;
    while (idx < cur->natom) {
      /*margin to the lower faces: x + skin, negative once the atom is too far below*/
      vec<real> margin_lo;
      vecaddv(margin_lo, cur->x[idx], grid->skin);

      /*margin to the upper faces: (len - x) + skin, negative once too far above*/
      vec<real> to_upper;
      vecsubv(to_upper, grid->len, cur->x[idx]);
      vec<real> margin_hi;
      vecaddv(margin_hi, to_upper, grid->skin);

      real worst = min(min3(margin_lo.x, margin_lo.y, margin_lo.z), min3(margin_hi.x, margin_hi.y, margin_hi.z));
      escaped = (worst < 0) ? 1 : escaped;
      idx++;
    }
  }
  return escaped;
}

/*
 * Copy the list of entry `src` to position `head` inside the same arrays and
 * register it as the list of entry `dst`.
 *
 *   base  : packed list storage
 *   first : first[k] .. first[k + 1] delimit the list of entry k
 *   head  : write position in base for the copy
 *
 * Sets first[dst] = head and returns the position just past the copied list.
 *
 * The source bounds are read BEFORE first[dst] is written, so dst == src and
 * dst == src + 1 are handled correctly. Elements are copied front to back;
 * a destination that starts inside the source range is still not supported.
 */
INLINE int copy_list(long *base, int *first, int dst, int src, int head) {
  const int src_begin = first[src];
  const int src_end = first[src + 1];
  first[dst] = head;
  for (int i = src_begin; i < src_end; i++) {
    base[head++] = base[i];
  }
  return head;
}
/*
 * Pair-valued variant of the in-place list copy: copy the list of entry `src`
 * to position `head` inside the same arrays and register it as the list of
 * entry `dst`.
 *
 *   base  : packed list storage, two longs per element
 *   first : first[k] .. first[k + 1] delimit the list of entry k
 *   head  : write position in base for the copy
 *
 * Sets first[dst] = head and returns the position just past the copied list.
 *
 * The source bounds are read BEFORE first[dst] is written, so dst == src and
 * dst == src + 1 are handled correctly. Elements are copied front to back;
 * a destination that starts inside the source range is still not supported.
 */
INLINE int copy_list2(long (*base)[2], int *first, int dst, int src, int head) {
  const int src_begin = first[src];
  const int src_end = first[src + 1];
  first[dst] = head;
  for (int i = src_begin; i < src_end; i++) {
    base[head][0] = base[i][0];
    base[head][1] = base[i][1];
    head++;
  }
  return head;
}
/*
 * Int-triple variant of the in-place list copy: copy the list of entry `src`
 * to position `head` inside the same arrays and register it as the list of
 * entry `dst`.
 *
 *   base  : packed list storage, three ints per element
 *   first : first[k] .. first[k + 1] delimit the list of entry k
 *   head  : write position in base for the copy
 *
 * Sets first[dst] = head and returns the position just past the copied list.
 *
 * The source bounds are read BEFORE first[dst] is written, so dst == src and
 * dst == src + 1 are handled correctly. Elements are copied front to back;
 * a destination that starts inside the source range is still not supported.
 */
INLINE int copy_list3i(int (*base)[3], int *first, int dst, int src, int head) {
  const int src_begin = first[src];
  const int src_end = first[src + 1];
  first[dst] = head;
  for (int i = src_begin; i < src_end; i++) {
    base[head][0] = base[i][0];
    base[head][1] = base[i][1];
    base[head][2] = base[i][2];
    head++;
  }
  return head;
}
/*
 * Append the list of entry `src` from (basesrc, firstsrc) to basedst starting
 * at `head`; returns the position just past the copied elements.
 *
 * firstdst and dst are accepted but not used: firstdst[dst] is NOT written.
 * NOTE(review): copy_2list2 and copy_2list3i do write firstdst[dst] = head;
 * confirm whether the difference is intentional.
 */
INLINE int copy_2list(long *basedst, long *basesrc, int *firstdst, int *firstsrc, int dst, int src, int head) {
  int rd = firstsrc[src];
  const int rd_end = firstsrc[src + 1];
  while (rd < rd_end) {
    basedst[head] = basesrc[rd];
    head++;
    rd++;
  }
  return head;
}
/*
 * Copy the pair-valued list of entry `src` from (basesrc, firstsrc) into
 * basedst starting at `head`, and record firstdst[dst] = head.
 * Returns the position just past the copied elements.
 */
INLINE int copy_2list2(long (*basedst)[2], long (*basesrc)[2], int *firstdst, int *firstsrc, int dst, int src, int head) {
  /*the list start is recorded first, the source bounds are read afterwards*/
  firstdst[dst] = head;
  int rd = firstsrc[src];
  const int rd_end = firstsrc[src + 1];
  while (rd < rd_end) {
    for (int c = 0; c < 2; c++) {
      basedst[head][c] = basesrc[rd][c];
    }
    head++;
    rd++;
  }
  return head;
}
/*
 * Copy the int-triple list of entry `src` from (basesrc, firstsrc) into
 * basedst starting at `head`, and record firstdst[dst] = head.
 * Returns the position just past the copied elements.
 */
INLINE int copy_2list3i(int (*basedst)[3], int (*basesrc)[3], int *firstdst, int *firstsrc, int dst, int src, int head) {
  /*the list start is recorded first, the source bounds are read afterwards*/
  firstdst[dst] = head;
  int rd = firstsrc[src];
  const int rd_end = firstsrc[src + 1];
  while (rd < rd_end) {
    for (int c = 0; c < 3; c++) {
      basedst[head][c] = basesrc[rd][c];
    }
    head++;
    rd++;
  }
  return head;
}
// #undef __sw__

#ifdef __sw__
void cell_export_sw(cellgrid_t *grid);
void cell_import_sw(cellgrid_t *grid);
/*
 * cell_export for builds with __sw__ defined: forwards to cell_export_sw(),
 * which is only declared in this file and implemented elsewhere.
 */
void cell_export(cellgrid_t *grid) {
  cell_export_sw(grid);
}
/*
 * cell_import for builds with __sw__ defined: forwards to cell_import_sw(),
 * which is only declared in this file and implemented elsewhere.
 */
void cell_import(cellgrid_t *grid) {
  cell_import_sw(grid);
}

#else
/*
 * Split the atoms of every local cell into "stay" and "export".
 *
 * An atom is exported when one of its cell-relative coordinates is < 0 or
 * exceeds the cell length by more than 1e-8. Within the cell's own arrays:
 *   - staying atoms are compacted to slots [0, natom_stay),
 *   - exported atoms are moved to the tail slots
 *     [CELL_CAP - nexport, CELL_CAP).
 * Afterwards cell->natom is the number of staying atoms and cell->nexport the
 * number of exported ones.
 *
 * Only the fields handled by atomcpy() are moved. atomcpy() does not copy
 * rmass, so rmass is stale for moved atoms until it is recomputed from mass.
 * Migration of the per-atom bonded / chain2 / exclusion / scaling / improper
 * lists (copy_list* helpers) is currently disabled.
 */
void cell_export(cellgrid_t *grid) {
  FOREACH_LOCAL_CELL(grid, ii, jj, kk, cell) {
    int is_stay[CELL_CAP];
    int natom_export = 0;

    /*pass 1: classify every atom*/
    for (int i = 0; i < cell->natom; i++) {
      /*distance to the upper faces: len - x*/
      vec<real> delhi;
      vecsubv(delhi, grid->len, cell->x[i]);

      /*the upper faces get a 1e-8 tolerance, the lower faces none*/
      real del_min = min(min3(cell->x[i].x, cell->x[i].y, cell->x[i].z), min3(delhi.x, delhi.y, delhi.z) + 1e-8);
      if (del_min < 0) {
        natom_export++;
        is_stay[i] = 0;
      } else {
        is_stay[i] = 1;
      }
    }

    /*pass 2: compact in place*/
    int stay_ptr = 0;
    int export_ptr = CELL_CAP - natom_export;
    for (int i = 0; i < cell->natom; i++) {
      if (is_stay[i]) {
        /*stay_ptr <= i, so this never overwrites an unprocessed atom*/
        atomcpy(cell, stay_ptr, cell, i);
        stay_ptr++;
      } else {
        /*
         * The export slot must not be a slot that still holds an unprocessed
         * atom (index in (i, natom)); that happens when the cell is so full
         * that the tail region overlaps the live atoms.
         */
        assert(export_ptr <= i || export_ptr >= cell->natom);
        atomcpy(cell, export_ptr, cell, i);
        export_ptr++;
      }
    }

    cell->natom = stay_ptr;
    cell->nexport = natom_export;
  }
}
#define SUB_CELL 4
/*
 * Pull exported atoms into the local cells that now contain them, then
 * reorder the atoms of each local cell by sub-cell bucket.
 *
 * Import: for every local cell, the export slots
 * [CELL_CAP - nexport, CELL_CAP) of each neighbour are scanned. An atom is
 * taken when, expressed relative to this cell's basis, every coordinate is
 * >= 0 and at least 1e-8 below the cell length. Imported atoms are appended
 * after the cell's current atoms with their position rebased to this cell.
 * NOTE(review): cell_export exports any atom with a coordinate < 0, but the
 * acceptance test here rejects coordinates within 1e-8 of the upper face, so
 * an atom exported with a coordinate in (-1e-8, 0) looks like it is accepted
 * by no cell -- confirm.
 *
 * Reorder: each cell is divided into SUB_CELL^3 sub-cells; atoms are sorted
 * (counting sort, order within a bucket preserved) by the value that hib holds
 * for their sub-cell. gen_hilbert() is expected to fill hib with bucket ids in
 * [0, SUB_CELL^3) -- TODO confirm against hilbert.h. rmass is recomputed from
 * mass for every atom of the cell.
 *
 * Migration of the per-atom bonded / chain2 / exclusion / scaling / improper
 * lists is currently disabled, so the corresponding counters printed at the
 * end are always 0.
 */
void cell_import(cellgrid_t *grid) {
  int hib[SUB_CELL][SUB_CELL][SUB_CELL];
  gen_hilbert((int*)hib, SUB_CELL);
  /*largest valid index in each dimension of hib*/
  const int sub_max = SUB_CELL - 1;
  celldata_t tmp;
  int total_n_atom = 0;
  int tot_bonded_import = 0, tot_chain2_import = 0, tot_excl_import = 0, tot_scal_import = 0, tot_impr_import = 0;
  FOREACH_LOCAL_CELL(grid, ii, jj, kk, icell) {
    int nimport = 0;

    FOREACH_NEIGHBOR(grid, ii, jj, kk, dx, dy, dz, jcell) {
      /*offset that turns jcell-relative coordinates into icell-relative ones*/
      vec<real> diff_basis;
      vecsubv(diff_basis, jcell->basis, icell->basis);
      for (int j = CELL_CAP - jcell->nexport; j < CELL_CAP; j++) {
        vec<real> xj;
        vecaddv(xj, jcell->x[j], diff_basis);
        vec<real> delhi;
        vecsubv(delhi, grid->len, xj);
        real del_min = min(min3(xj.x, xj.y, xj.z), min3(delhi.x, delhi.y, delhi.z) - 1e-8);
        if (del_min >= 0) {
          int dst_id = icell->natom + nimport;
          /*
           * Check capacity BEFORE writing: the new atom must not reach
           * icell's own export slots, which neighbours may still have to read.
           */
          assert(dst_id + 1 + icell->nexport < CELL_CAP);
          nimport++;
          atomcpy(icell, dst_id, jcell, j);
          /*atomcpy copies the position verbatim; store the rebased one*/
          veccpy(icell->x[dst_id], xj);
        }
      }
    }
    assert(icell->natom + nimport + icell->nexport < CELL_CAP);

    icell->natom += nimport;
    total_n_atom += icell->natom;

    /*counting sort of the cell's atoms by sub-cell bucket*/
    int bkt[CELL_CAP];
    int head[SUB_CELL*SUB_CELL*SUB_CELL];
    for (int i = 0; i < SUB_CELL*SUB_CELL*SUB_CELL; i ++){
      head[i] = 0;
    }
    for (int i = 0; i < icell->natom; i ++){
      /*
       * Clamp to [0, SUB_CELL - 1]: atoms that stayed may sit on or slightly
       * past the upper face, which would otherwise index hib out of bounds.
       */
      int subx = max(min((int)(icell->x[i].x * grid->rlen.x * SUB_CELL), sub_max), 0);
      int suby = max(min((int)(icell->x[i].y * grid->rlen.y * SUB_CELL), sub_max), 0);
      int subz = max(min((int)(icell->x[i].z * grid->rlen.z * SUB_CELL), sub_max), 0);
      bkt[i] = hib[subx][suby][subz];
      head[bkt[i]] ++;
    }
    /*inclusive prefix sum of the bucket sizes ...*/
    for (int i = 1; i < SUB_CELL*SUB_CELL*SUB_CELL; i ++) {
      head[i] = head[i - 1] + head[i];
    }
    /*... shifted by one so that head[b] is the first slot of bucket b*/
    for (int i = SUB_CELL*SUB_CELL*SUB_CELL - 1; i > 0; i --) {
      head[i] = head[i - 1];
    }
    head[0] = 0;
    /*reorder[k] = old index of the atom that ends up in slot k*/
    int reorder[CELL_CAP];
    for (int i = 0; i < icell->natom; i ++) {
      reorder[head[bkt[i]] ++] = i;
    }

    /*gather into tmp in sorted order, then copy back*/
    veccpy(tmp.basis, icell->basis);
    for (int i = 0; i < icell->natom; i ++) {
      atomcpy(&tmp, i, icell, reorder[i]);
    }
    for (int i = 0; i < icell->natom; i++) {
      atomcpy(icell, i, &tmp, i);
      icell->rmass[i] = 1.0 / icell->mass[i];
    }
  }
  printf("natom: %d, import: b=%d, c=%d, e=%d, s=%d, i=%d\n", total_n_atom, tot_bonded_import, tot_chain2_import, tot_excl_import, tot_scal_import, tot_impr_import);
}
#endif
#define SUB_CELL 4
/*
 * Reorder the atoms of every local cell by sub-cell bucket.
 *
 * Each cell is divided into SUB_CELL^3 sub-cells; atoms are sorted (counting
 * sort, order within a bucket preserved) by the value that hib holds for their
 * sub-cell. gen_hilbert() is expected to fill hib with bucket ids in
 * [0, SUB_CELL^3) -- TODO confirm against hilbert.h.
 *
 * Only the fields handled by atomcpy() are permuted; rmass is recomputed from
 * mass afterwards.
 */
void cell_sort(cellgrid_t *grid) {
  int hib[SUB_CELL][SUB_CELL][SUB_CELL];
  gen_hilbert((int*)hib, SUB_CELL);
  /*largest valid index in each dimension of hib*/
  const int sub_max = SUB_CELL - 1;
  celldata_t tmp;
  FOREACH_LOCAL_CELL(grid, cx, cy, cz, cell) {
    int bkt[CELL_CAP];
    int head[SUB_CELL * SUB_CELL * SUB_CELL];
    for (int i = 0; i < SUB_CELL * SUB_CELL * SUB_CELL; i++) {
      head[i] = 0;
    }
    for (int i = 0; i < cell->natom; i++) {
      /*
       * Clamp to [0, SUB_CELL - 1]: an atom on or slightly past the upper face
       * would otherwise index hib out of bounds.
       */
      int subx = max(min((int)(cell->x[i].x * grid->rlen.x * SUB_CELL), sub_max), 0);
      int suby = max(min((int)(cell->x[i].y * grid->rlen.y * SUB_CELL), sub_max), 0);
      int subz = max(min((int)(cell->x[i].z * grid->rlen.z * SUB_CELL), sub_max), 0);
      bkt[i] = hib[subx][suby][subz];
      head[bkt[i]]++;
    }
    /*inclusive prefix sum of the bucket sizes ...*/
    for (int i = 1; i < SUB_CELL * SUB_CELL * SUB_CELL; i++) {
      head[i] = head[i - 1] + head[i];
    }
    /*... shifted by one so that head[b] is the first slot of bucket b*/
    for (int i = SUB_CELL * SUB_CELL * SUB_CELL - 1; i > 0; i--) {
      head[i] = head[i - 1];
    }
    head[0] = 0;
    /*reorder[k] = old index of the atom that ends up in slot k*/
    int reorder[CELL_CAP];
    for (int i = 0; i < cell->natom; i++) {
      reorder[head[bkt[i]]++] = i;
    }

    veccpy(tmp.basis, cell->basis);

    /*gather into tmp in sorted order; without this step tmp is uninitialised*/
    for (int i = 0; i < cell->natom; i++) {
      atomcpy(&tmp, i, cell, reorder[i]);
    }
    /*copy the sorted atoms back and refresh the reciprocal mass*/
    for (int i = 0; i < cell->natom; i++) {
      atomcpy(cell, i, &tmp, i);
      cell->rmass[i] = 1.0 / cell->mass[i];
    }
  }
}
